# Attach packages. The script deliberately does NOT call rm(list = ls()):
# wiping the caller's workspace is a side effect a script should not have.
# Run it in a fresh R session (e.g. `Rscript`) if a clean environment is needed.
# NOTE(review): grid.arrange()/arrangeGrob() used further down are not from
# gridExtra (which is never attached); presumably they are provided by one of
# the packages below (ggtern) — TODO confirm.
library(FactorHet)
library(tidyverse)
library(ggtern)
library(ggpubr)

############### Read in model output ###################
# List of fitted models; element k is used below as the k-group model.
estimated_models <- readRDS("main_models/out_BIC.RDS")
est_mbo2 <- estimated_models[[2]]
est_mbo3 <- estimated_models[[3]]

# Set baselines for plotting: one baseline level per conjoint factor,
# keyed by the raw factor name.
# NOTE(review): `baselines` is not referenced elsewhere in the visible script.
baselines <- list(
  Gender  = "female",
  Ed      = "No formal",
  Lang    = "Fluent",
  Country = "India",
  Job     = "Janitor",
  Exp     = "No job training or prior experience",
  Plans   = "Will look for work after arriving in the U.S.",
  Reason  = "Family",
  Trips   = "Never been to the U.S."
)

## Data formatting for plots

# Raw factor names (as stored in the model output) -> display labels.
# Stored as parallel vectors `orig` / `clean` for use with recode().
clean_factor <- local({
  labels <- c(
    Country = "Country",
    Ed      = "Education",
    Exp     = "Job exp",
    Gender  = "Gender",
    Job     = "Job",
    Lang    = "Language",
    Plans   = "Plans",
    Reason  = "Reason",
    Trips   = "Trips"
  )
  list(orig = names(labels), clean = unname(labels))
})

# Display labels for factor levels, as parallel vectors `orig` (raw level
# name) and `clean` (label) for use with recode().
#   clean_level     -> relabels the `level` column
#   clean_fmt_level -> relabels the `fmt_level` column (same raw names, but
#                      spelled-out labels for the two college levels)
# Most labels equal the raw name, so only the entries that differ are listed;
# clean_fmt_level is derived from clean_level so the two tables cannot drift.
clean_level <- local({
  orig <- c(">5 years", "1-2 years", "2yCol", "3-5 years", "6 months with family",
            "Broken", "Child care provider", "China", "Col", "Computer programmer",
            "Construction worker", "Doctor", "Once w/o authorization", "Family",
            "Female", "Financial analyst", "Fluent", "France", "Gardener", "Germany",
            "GradDeg", "Grade4", "Grade8", "Has contract", "HS", "India", "Interpreter",
            "Iraq", "Janitor", "Job", "Male", "Mexico", "Multiple times with visa",
            "Never been", "No contract, had interviews", "No formal", "No plans", "None",
            "Nurse", "Once with visa", "Persecution", "Philippines", "Poland", "Research scientist",
            "Somalia", "Sudan", "Teacher", "Unable", "Waiter", "Will look after arrival")
  renamed <- c("2yCol"   = "2 year col",
               "GradDeg" = "Grad degree",
               "Grade4"  = "Grade 4",
               "Grade8"  = "Grade 8",
               "HS"      = "High school")
  clean <- orig
  hit <- orig %in% names(renamed)
  clean[hit] <- unname(renamed[orig[hit]])
  list(orig = orig, clean = clean)
})

clean_fmt_level <- clean_level
clean_fmt_level$clean[clean_fmt_level$orig == "2yCol"] <- "2 year college"
clean_fmt_level$clean[clean_fmt_level$orig == "Col"] <- "College"

# Axis labels for the moderator plots: raw moderator term (`orig`) -> display
# label (`clean`), used as scale_x_discrete(breaks = orig, labels = clean).
# Each term appears exactly once (the earlier table repeated three ethnicity
# entries verbatim).
clean_mfx_plot <- list(
  orig = c("scale_hisp_prej_flip", "census_div(Few Immigrants)",
           "census_div(Many Immigrants, Majority Hispanic)",
           "ppEducat(Bachelor's degree or higher)", "ppEducat(Some college)", "ppEducat(High school)",
           "party_ID(Strong Democrat)", "party_ID(Not Strong Democrat)", "party_ID(Leans Democrat)",
           "party_ID(Undecided/Independent/Other)", "party_ID(Leans Republican)", "party_ID(Not Strong Republican)",
           "ppEthm(White, Non-Hispanic)", "ppEthm(Black, Non-Hispanic)", "ppEthm(Other, Non-Hispanic)",
           "ppEthm(Hispanic)", "ppEthm(2+ Races, Non-Hispanic)"),
  clean = c("Hispanic prejudice score", "ZIP(Few immigrants)",
            "ZIP(Many immigrants, majority Hisp)",
            "Ed(Bachelor's or higher)", "Ed(Some college)", "Ed(High school)",
            "Party ID(Strong Democrat)", "Party ID(Not strong Democrat)", "Party ID(Leans Democrat)",
            "Party ID(Undecided/Indep/Other)", "Party ID(Leans Republican)", "Party ID(Not strong Republican)",
            "White, Non-Hispanic", "Black, Non-Hispanic", "Other, Non-Hispanic",
            "Hispanic", "2+ Races, Non-Hispanic")
)

# Categorical moderators in bar charts: raw moderator value (`orig`) ->
# display label (`clean`), for use with recode().
clean_bar <- local({
  lookup <- c(
    "Strong Republican"                      = "Strong Republican",
    "Not Strong Republican"                  = "Not strong Republican",
    "Leans Republican"                       = "Leans Republican",
    "Undecided/Independent/Other"            = "Undecided/Indep/Other",
    "Leans Democrat"                         = "Leans Democrat",
    "Not Strong Democrat"                    = "Not strong Democrat",
    "Strong Democrat"                        = "Strong Democrat",
    "Less than high school"                  = "< High school",
    "High school"                            = "High school",
    "Some college"                           = "Some college",
    "Bachelor's degree or higher"            = "Bachelor's or higher",
    "Many Immigrants, Majority Not Hispanic" = "Many immigrants, majority not Hisp",
    "Many Immigrants, Majority Hispanic"     = "Many immigrants, majority Hisp",
    "Few Immigrants"                         = "Few immigrants"
  )
  list(orig = names(lookup), clean = unname(lookup))
})

########################## Analysis: 3 Groups ################################

## Posterior predictive weights
post_pred_3_grp <- est_mbo3$posterior$posterior_predictive

# Long format: one row per (unit, group); `Group` holds the suffix of the
# original "group_<k>" column name and `post_weights` that column's value.
post_pred_3_grpa <- pivot_longer(
  post_pred_3_grp,
  cols = starts_with("group_"),
  names_to = "Group",
  names_prefix = "group_",
  values_to = "post_weights",
  values_drop_na = TRUE
)

# Mean posterior predictive weight of each of the three groups.
post_pred_3_mean <- vapply(
  1:3,
  function(k) mean(post_pred_3_grpa$post_weights[post_pred_3_grpa$Group == k]),
  numeric(1)
)


### Plots

## Get marginal AME calculations
est_AME3 <- AME(est_mbo3)

## Data formatting for plots: replace raw factor / level names by display labels
data_ame3 <- est_AME3$data %>%
  ungroup() %>%
  mutate(
    factor = recode(factor, !!!setNames(clean_factor$clean, clean_factor$orig)),
    level = recode(level, !!!setNames(clean_level$clean, clean_level$orig)),
    fmt_level = recode(fmt_level, !!!setNames(clean_fmt_level$clean, clean_fmt_level$orig))
  )

### AME plot for 3 groups: estimate with interval marginal_effect +/- 1.96 * sqrt(var),
### one column per group, labelled with the group's mean posterior predictive
### weight (post_pred_3_mean).
ame3_plot <- ggplot(data_ame3,
                    aes(x = fmt_level,
                        y = marginal_effect,
                        ymin = marginal_effect - 1.96 * sqrt(var),
                        ymax = marginal_effect + 1.96 * sqrt(var))) +
  geom_hline(aes(yintercept = 0), linetype = 'dashed') +
  geom_point(aes(col = factor(group), pch = factor(group))) +
  geom_errorbar(aes(col = factor(group))) +
  # Baseline levels are drawn once more, without the group colour mapping.
  geom_point(data = data_ame3 %>% filter(baseline)) +
  # Zoom via the coordinate system: ylim() would turn out-of-range y/ymin/ymax
  # into NA and silently drop those points and whole error bars.
  coord_flip(ylim = c(-0.4, 0.4)) +
  facet_grid(factor ~ paste0('Group ', group, ": ",
                             round(post_pred_3_mean[group], 3) * 100, "%"),
             scales = 'free_y', space = 'free_y', switch = 'y',
             labeller = label_wrap_gen(width = 10, multi_line = TRUE)) +
  theme_bw(base_size = 10) +
  theme(strip.text.y.left = element_text(angle = 0),
        panel.spacing = unit(0.1, 'lines'), strip.placement = 'outside',
        axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)) +
  xlab('Factor') + ylab('Effect') +
  theme(legend.position = 'none')


#### Calculate AMIEs for selected factor pairs (3-group model),
#### each relative to the baseline levels given.
AMIE_country_job3 <- AMIE(object = est_mbo3,
                          baseline = list(Country = "Germany", Job = "Janitor"))
AMIE_country_ed3 <- AMIE(object = est_mbo3,
                         baseline = list(Country = "Germany", Ed = "No formal"))
AMIE_country_reason3 <- AMIE(object = est_mbo3,
                             baseline = list(Country = "Germany", Reason = "Family"))
AMIE_ed_job3 <- AMIE(object = est_mbo3,
                     baseline = list(Ed = "GradDeg", Job = "Doctor"))


## Moderator plots: data behind the marginal "effects" of the moderators
est_mfx3 <- margeff_moderators(est_mbo3)
data_mfx3 <- est_mfx3$plot$data
data_mfx3$num_groups <- rep("3 groups", times = nrow(data_mfx3))

## Other plots of moderators
est_mod_plots3 <- posterior_by_moderators(est_mbo3)

# Categorical moderators in bar charts: relabel the variable names and values
mod_bar_col_sum3 <- est_mod_plots3$discrete_bar
mod_bar_col_sum3$data$variable <- recode(
  mod_bar_col_sum3$data$variable,
  party_ID = "Party ID",
  ppEducat = "Education",
  census_div = "ZIP diversity",
  ppEthm = "Ethnicity"
)
mod_bar_col_sum3$data$value <- recode(
  mod_bar_col_sum3$data$value,
  !!!setNames(clean_bar$clean, clean_bar$orig)
)
data_bar_col3 <- mod_bar_col_sum3$data
data_bar_col3$num_groups <- rep("3 groups", times = nrow(data_bar_col3))

# Continuous moderators in boxplots
mod_boxplot_cont3 <- est_mod_plots3$continuous
mod_boxplot_cont3$data$variable <- recode(
  mod_boxplot_cont3$data$variable,
  scale_hisp_prej_flip = "Hispanic Prej Scale"
)
data_cont3 <- mod_boxplot_cont3$data
data_cont3$num_groups <- rep("3 groups", times = nrow(data_cont3))

# Check lambda value. print() so it is shown even when the script is
# source()d (bare top-level values are only auto-printed by the console/Rscript).
print(est_mbo3$parameters$eff_lambda)

########################## Analysis: 2 groups ################################

## Posterior predictive weights
post_pred_2_grp <- est_mbo2$posterior$posterior_predictive

# Long format: one row per (unit, group); `Group` holds the suffix of the
# original "group_<k>" column name and `post_weights` that column's value.
post_pred_2_grpa <- pivot_longer(
  post_pred_2_grp,
  cols = starts_with("group_"),
  names_to = "Group",
  names_prefix = "group_",
  values_to = "post_weights",
  values_drop_na = TRUE
)

# Mean posterior predictive weight of each of the two groups.
post_pred_2_mean <- vapply(
  1:2,
  function(k) mean(post_pred_2_grpa$post_weights[post_pred_2_grpa$Group == k]),
  numeric(1)
)


### Graphs
# Calculate marginal AMEs
est_AME2 <- AME(est_mbo2)

# Format data for plotting. ungroup() first, exactly as for the 3-group data,
# so data_ame2 is a plain (ungrouped) table like data_ame3.
data_ame2 <- est_AME2$data %>%
  ungroup() %>%
  mutate(
    factor = recode(factor, !!!setNames(clean_factor$clean, clean_factor$orig)),
    level = recode(level, !!!setNames(clean_level$clean, clean_level$orig)),
    fmt_level = recode(fmt_level, !!!setNames(clean_fmt_level$clean, clean_fmt_level$orig))
  )

# AME plot for 2 groups: estimate with interval marginal_effect +/- 1.96 * sqrt(var),
# one column per group, labelled with the group's mean posterior predictive
# weight (post_pred_2_mean).
ame2_plot <- ggplot(data_ame2,
                    aes(x = fmt_level,
                        y = marginal_effect,
                        ymin = marginal_effect - 1.96 * sqrt(var),
                        ymax = marginal_effect + 1.96 * sqrt(var))) +
  geom_hline(aes(yintercept = 0), linetype = 'dashed') +
  geom_point(aes(col = factor(group), pch = factor(group))) +
  geom_errorbar(aes(col = factor(group))) +
  # Baseline levels are drawn once more, without the group colour mapping.
  geom_point(data = data_ame2 %>% filter(baseline)) +
  # Zoom via the coordinate system: ylim() would turn out-of-range y/ymin/ymax
  # into NA and silently drop those points and whole error bars.
  coord_flip(ylim = c(-0.4, 0.4)) +
  facet_grid(factor ~ paste0('Group ', group, ": ",
                             round(post_pred_2_mean[group], 3) * 100, "%"),
             scales = 'free_y', space = 'free_y', switch = 'y',
             labeller = label_wrap_gen(width = 10, multi_line = TRUE)) +
  theme_bw(base_size = 10) +
  theme(strip.text.y.left = element_text(angle = 0),
        panel.spacing = unit(0.1, 'lines'), strip.placement = 'outside',
        axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)) +
  xlab('Factor') + ylab('Effect') +
  theme(legend.position = 'none')

# Get AMIEs for selected factor pairs (2-group model), each relative to the
# baseline levels given.
AMIE_country_job2 <- AMIE(object = est_mbo2,
                          baseline = list(Country = "Germany", Job = "Janitor"))
AMIE_country_ed2 <- AMIE(object = est_mbo2,
                         baseline = list(Country = "Germany", Ed = "No formal"))
AMIE_country_reason2 <- AMIE(object = est_mbo2,
                             baseline = list(Country = "Germany", Reason = "Family"))
AMIE_ed_job2 <- AMIE(object = est_mbo2,
                     baseline = list(Ed = "GradDeg", Job = "Doctor"))

## Moderator plots: data behind the marginal "effects" of the moderators
est_mfx2 <- margeff_moderators(est_mbo2)
data_mfx2 <- est_mfx2$plot$data
data_mfx2$num_groups <- rep("2 groups", times = nrow(data_mfx2))

# Alternative moderator plots
est_mod_plots2 <- posterior_by_moderators(est_mbo2)

# Bar charts for categorical covariates: relabel the variable names and values
mod_bar_col_sum2 <- est_mod_plots2$discrete_bar
mod_bar_col_sum2$data$variable <- recode(
  mod_bar_col_sum2$data$variable,
  party_ID = "Party ID",
  ppEducat = "Education",
  census_div = "ZIP diversity",
  ppEthm = "Ethnicity"
)
mod_bar_col_sum2$data$value <- recode(
  mod_bar_col_sum2$data$value,
  !!!setNames(clean_bar$clean, clean_bar$orig)
)
data_bar_col2 <- mod_bar_col_sum2$data
data_bar_col2$num_groups <- rep("2 groups", times = nrow(data_bar_col2))

# Boxplots for continuous covariates
mod_boxplot_cont2 <- est_mod_plots2$continuous
mod_boxplot_cont2$data$variable <- recode(
  mod_boxplot_cont2$data$variable,
  scale_hisp_prej_flip = "Hispanic Prej Scale"
)
data_cont2 <- mod_boxplot_cont2$data
data_cont2$num_groups <- rep("2 groups", times = nrow(data_cont2))


# Check lambda value. print() so it is shown even when the script is
# source()d (bare top-level values are only auto-printed by the console/Rscript).
print(est_mbo2$parameters$eff_lambda)

################ FIGURE 3 ##################
#### Plots together: 2-group AMEs (left) beside 3-group AMEs (right)
grid.arrange(ame2_plot, ame3_plot,
             ncol = 2, widths = c(1, 1.2))

# Save plot: rebuild the same two-panel layout as a grob and write it to PDF
g <- arrangeGrob(ame2_plot, ame3_plot,
                 ncol = 2, widths = c(1, 1.2))
ggsave(filename = "figures/AME_plot.pdf", plot = g,
       width = 8, height = 6)

################ FIGURE A13 ##################
#### Marginal means plots
# AME() called with baseline = NA and ignore_restrictions = TRUE; its plot is
# used here as the marginal-means figure.
est_MM_3 <- AME(est_mbo3, ignore_restrictions = TRUE, baseline = NA)

# Re-facet the returned plot: factors in rows, one column per group.
# (`scales` is spelled out rather than relying on partial matching of `scale`.)
MM_3_plot <- est_MM_3$plot +
  facet_grid(factor ~ paste0('Group ', group),
             scales = 'free_y', space = 'free_y', switch = 'y',
             labeller = label_wrap_gen(width = 10, multi_line = TRUE))

ggsave(plot = MM_3_plot, filename = 'figures/marginal_means.pdf', width = 8, height = 6)


################ FIGURE 6 ##################

# Visualize the effect of moderators (i.e., marginal "effects" of moderators)
mfx_data <- rbind(data_mfx2, data_mfx3)

# Strip labels for the `group` facet variable. labeller() treats a character
# vector as a lookup table indexed by the facet values, so it must be named;
# the previous unnamed vector produced NA strip labels.
# NOTE(review): assumes `group` takes the values 1, 2, 3 — TODO confirm.
group.labs <- c("1" = "Group 1", "2" = "Group 2", "3" = "Group 3")

mfx_plot <- ggplot(mfx_data, aes(alpha = factor(sig))) +
  geom_point(aes(x = fmt_name, y = baseline)) +
  # Arrow from the baseline to the changed posterior predictive probability.
  # NOTE(review): x uses fmt_name but xend uses variable; confirm these hold
  # the same category values, otherwise arrows run between axis categories.
  geom_segment(aes(x = fmt_name, xend = variable, y = baseline, yend = changed),
               arrow = arrow(length = unit(0.03, 'npc'))) +
  coord_flip() +
  facet_wrap(~num_groups + group, nrow = 1,
             labeller = labeller(group = group.labs)) +
  theme_bw() +
  ylab('Posterior Predictive Probability of Group Membership') +
  xlab('Covariate') +
  # First level of factor(sig) is drawn at alpha 0.25, the second fully opaque.
  scale_alpha_manual(values = c(0.25, 1), guide = 'none') +
  scale_x_discrete(breaks = clean_mfx_plot$orig,
                   labels = clean_mfx_plot$clean)

ggsave(plot = mfx_plot,
       filename = 'figures/mod_plot.pdf',
       width = 8, height = 6)


################ FIGURE 5 ##################
# Combine alternative moderator plots (2- and 3-group models side by side)

# Bar charts of the categorical moderators
data_bar_col <- rbind(data_bar_col2, data_bar_col3)

col_bar_plot <- ggplot(data = data_bar_col,
                       aes(y = norm_weight, x = value, fill = group)) +
  geom_bar(stat = "identity",
           position = position_dodge2(reverse = TRUE)) +
  xlab("Respondent characteristics") +
  ylab("Proportion") +
  labs(fill = "Group number") +
  facet_grid(variable ~ num_groups, scales = "free") +
  theme(axis.text.x = element_text(),
        text = element_text(size = 12)) +
  coord_flip()

# Boxplots of the continuous moderator, weighted by plot_weight
data_cont <- rbind(data_cont2, data_cont3)
cont_plot <- ggplot(data = data_cont,
                    aes(x = group, y = value, weight = plot_weight, fill = group)) +
  geom_boxplot() +
  labs(title = "",
       y = "Hispanic prejudice \n score",
       x = "Group number") +
  coord_flip() +
  # Reversed group order on the group axis (when `group` is a factor)
  scale_x_discrete(limits = rev(levels(data_cont$group))) +
  facet_grid(~num_groups, scales = "free", space = "free",
             labeller = label_wrap_gen(width = 10, multi_line = TRUE)) +
  theme(legend.position = "none",
        axis.title.y = element_text(angle = 90),
        text = element_text(size = 12))

# Plot alternative moderators together, then save the same layout
grid.arrange(cont_plot, col_bar_plot,
             widths = c(1, 3), ncol = 2)
g <- arrangeGrob(cont_plot, col_bar_plot,
                 widths = c(1, 3), ncol = 2)
ggsave(filename = "figures/mod_plot2.pdf", plot = g,
       width = 9.21, height = 5.39)

# Some information on party ID discussed in the main text.
# (1) Summed weight per group on Republican / Democrat / other ("Ind") labels.
data_bar_col %>%
  filter(variable == 'Party ID') %>%
  mutate(PID = case_when(
    grepl('Republican', value) ~ 'Rep',
    grepl('Democrat', value) ~ 'Dem',
    TRUE ~ 'Ind'
  )) %>%
  group_by(num_groups, group, PID) %>%
  summarize(norm_weight = sum(norm_weight)) %>%
  pivot_wider(id_cols = c(num_groups, group),
              names_from = PID,
              values_from = 'norm_weight') %>%
  print()

# (2) Weight per group on every individual Party ID category, rounded to 2 dp.
data_bar_col %>%
  filter(variable == 'Party ID') %>%
  pivot_wider(id_cols = c(num_groups, group),
              names_from = value,
              values_from = 'norm_weight') %>%
  mutate(across(where(is.numeric), function(x) round(x, 2))) %>%
  data.frame() %>%
  print()

### Learn about changes in group posterior probabilities from 2 to 3 groups
# Join the 3-group (x) and 2-group (y) posterior predictive tables on their
# shared `group` column; columns present in both get merge()'s default
# suffixes (.x from the 3-group table, .y from the 2-group table).
post_pred_both <- merge(x = post_pred_3_grp, y = post_pred_2_grp, by = "group")

# Useful function for making a table.
#
# For each pair (j in grp_var_nam2, i in grp_var_nam1) computes the mean over
# rows of data_obj[[i]] * data_obj[[j]].
#
# data_obj     : data.frame or tibble containing all the named numeric columns.
# grp_var_nam1 : character; columns that become the table's columns.
# grp_var_nam2 : character; columns that become the table's rows.
# Returns a data.frame with row names grp_var_nam2 and column names
# grp_var_nam1. An empty data_obj yields NaN entries, NA inputs yield NA.
group_table_fun <- function(data_obj, grp_var_nam1, grp_var_nam2) {
  # `[` with column names + as.matrix() gives plain numeric columns for both
  # data.frames and tibbles (data_obj[, i] returns a tibble for tibbles).
  row_vals <- as.matrix(data_obj[grp_var_nam2])
  col_vals <- as.matrix(data_obj[grp_var_nam1])
  # crossprod(row_vals, col_vals)[j, i] == sum(row_vals[, j] * col_vals[, i]);
  # dividing by the row count turns each sum into the mean of the products.
  as.data.frame(crossprod(row_vals, col_vals) / nrow(data_obj))
}


# Table based on the posterior predictive probabilities of group membership in
# each analysis: rows are the 3-group memberships (.x / group_3), columns the
# 2-group memberships (.y).
out <- group_table_fun(
  data_obj = post_pred_both,
  grp_var_nam1 = c("group_1.y", "group_2.y"),
  grp_var_nam2 = c("group_1.x", "group_2.x", "group_3")
)

# Rescale each column to sum to one before printing.
print(mutate(out, across(everything(), function(p) p / sum(p))))


################ FIGURE 4 ##################

### Ternary plot of the 3-group posterior predictive probabilities, coloured by
### the 2-group model's group-1 probability on a diverging scale centred at `mid`.
mid <- 0.5

tern_plot <- ggtern(data = post_pred_both,
                    aes(group_1.x, group_2.x, group_3, colour = group_1.y)) +
  geom_point(alpha = 0.5, size = 1) +
  theme_showarrows() +
  Tlab("Group\n2") +
  Llab("Group\n1") +
  Rlab("Group\n3") +
  Tarrowlab("Group 2") +
  Larrowlab("Group 1") +
  Rarrowlab("Group 3") +
  theme(legend.position = "bottom",
        tern.axis.padding = unit(0.5, "line")) +
  labs(colour = "Group 1 posterior predictive probability") +
  scale_color_gradient2(midpoint = mid,
                        low = "darkblue", mid = "white", high = "darkred")
ggsave("figures/ternary_plot.pdf", tern_plot, width = 6, height = 4)

################# Interaction effects #################
# Find the largest absolute AMIE for each factor pair and model. The results
# are collected into one named vector and print()ed so they are shown even
# when the script is source()d (bare top-level values are only auto-printed
# by the console / Rscript).
print(vapply(
  list(
    country_job2 = AMIE_country_job2, country_job3 = AMIE_country_job3,
    country_ed2 = AMIE_country_ed2, country_ed3 = AMIE_country_ed3,
    country_reason2 = AMIE_country_reason2, country_reason3 = AMIE_country_reason3,
    ed_job2 = AMIE_ed_job2, ed_job3 = AMIE_ed_job3
  ),
  function(amie) max(abs(amie$data[[1]]$AMIE)),
  numeric(1)
))

# Row(s) of the 3-group Ed x Job AMIE table attaining that maximum (all ties kept)
print(local({
  amie_ej3 <- AMIE_ed_job3$data[[1]]
  amie_ej3[which(abs(amie_ej3$AMIE) == max(abs(amie_ej3$AMIE))), ]
}))

# Get AME range for education and job (non-baseline levels, 3 groups)
print(summary(subset(data_ame3, !baseline & factor == "Education")$marginal_effect))
print(summary(subset(data_ame3, !baseline & factor == "Job")$marginal_effect))


################ FIGURE A16 ##################

### Plot largest: tile plot of the 3-group Ed x Job AMIEs (non-baseline cells),
### one panel per group.
data_amie_ej3 <- AMIE_ed_job3$data[[1]]
plot_amie_ej3 <- ggplot(filter(data_amie_ej3, !baseline)) +
  geom_tile(aes(x = Ed, y = Job, fill = AMIE)) +
  facet_wrap(~paste0('Group ', group)) +
  theme_bw() +
  theme(legend.position = 'bottom',
        panel.grid = element_blank(),
        axis.text.x = element_text(hjust = 1, vjust = 0, angle = 90)) +
  labs(fill = 'AMIE') +
  guides(fill = guide_colourbar(
    label.theme = element_text(size = 8, hjust = 1, angle = 90)
  ))

ggsave("figures/int_ed_job_append.pdf", plot_amie_ej3, width = 6, height = 4)

################ FIGURE A14 ##################

# Create appendix figure for pooled sample splits on main data
repeat_final_output <- readRDS("final_output/repeat_final_output.RDS")

# Median AME over the repeated runs for every (factor, level, fmt_level) and
# group, in wide form with one column per group label. names_sort = TRUE puts
# the group columns in label order on both sides of the alignment below.
median_repeat_2 <- repeat_final_output[[2]]$save_AME %>%
  filter(type == 'perm') %>%
  group_by(factor, level, fmt_level, group) %>%
  summarize(marginal_effect = median(marginal_effect), .groups = 'drop') %>%
  pivot_wider(id_cols = c('factor', 'level', 'fmt_level'),
              names_from = 'group', values_from = 'marginal_effect',
              names_sort = TRUE)

median_repeat_3 <- repeat_final_output[[3]]$save_AME %>%
  filter(type == 'perm') %>%
  group_by(factor, level, fmt_level, group) %>%
  summarize(marginal_effect = median(marginal_effect), .groups = 'drop') %>%
  pivot_wider(id_cols = c('factor', 'level', 'fmt_level'),
              names_from = 'group', values_from = 'marginal_effect',
              names_sort = TRUE)

# Need to permute these to align with the group labels used for the full data.
# The permutation of the original 20 simulations aligns them with each other.
#
# The full-data AMEs come from the raw AME() output rather than data_ame2 /
# data_ame3: those have display labels in factor/level/fmt_level, which no
# longer match the raw names in save_AME.
wide_original_ame3 <- est_AME3$data %>%
  ungroup() %>%
  filter(!is.na(baseline)) %>%
  pivot_wider(id_cols = c('factor', 'level', 'fmt_level'),
              names_from = 'group', values_from = 'marginal_effect',
              names_sort = TRUE)
wide_original_ame2 <- est_AME2$data %>%
  ungroup() %>%
  filter(!is.na(baseline)) %>%
  pivot_wider(id_cols = c('factor', 'level', 'fmt_level'),
              names_from = 'group', values_from = 'marginal_effect',
              names_sort = TRUE)

# Every full-data row must have a matching median row.
stopifnot(
  nrow(anti_join(wide_original_ame2, median_repeat_2,
                 by = c('factor', 'level', 'fmt_level'))) == 0,
  nrow(anti_join(wide_original_ame3, median_repeat_3,
                 by = c('factor', 'level', 'fmt_level'))) == 0
)

# Put the median rows in the same order as the full-data rows (group_by()
# above sorts them, which need not match the AME() output order), so row r of
# both matrices describes the same factor level.
median_repeat_2 <- wide_original_ame2 %>%
  select(all_of(c('factor', 'level', 'fmt_level'))) %>%
  left_join(median_repeat_2, by = c('factor', 'level', 'fmt_level'))
median_repeat_3 <- wide_original_ame3 %>%
  select(all_of(c('factor', 'level', 'fmt_level'))) %>%
  left_join(median_repeat_3, by = c('factor', 'level', 'fmt_level'))

# NOTE(review): assumes internal_align() compares the two matrices row by row
# and returns a permutation of the group labels — confirm against FactorHet.
perm_repeat_3 <- FactorHet:::internal_align(
  as.matrix(wide_original_ame3[, -1:-3]),
  as.matrix(median_repeat_3[, -1:-3])
)

perm_repeat_2 <- FactorHet:::internal_align(
  as.matrix(wide_original_ame2[, -1:-3]),
  as.matrix(median_repeat_2[, -1:-3])
)


# Relabel each repeated run's groups with the full-data labels (via the
# permutations computed above) and apply the display labels used elsewhere.
repeat_final_output[[2]]$save_AME <- repeat_final_output[[2]]$save_AME %>%
  ungroup() %>%
  mutate(
    group = match(group, perm_repeat_2),
    factor = recode(factor, !!!setNames(clean_factor$clean, clean_factor$orig)),
    level = recode(level, !!!setNames(clean_level$clean, clean_level$orig)),
    fmt_level = recode(fmt_level, !!!setNames(clean_fmt_level$clean, clean_fmt_level$orig))
  )

repeat_final_output[[3]]$save_AME <- repeat_final_output[[3]]$save_AME %>%
  ungroup() %>%
  mutate(
    group = match(group, perm_repeat_3),
    factor = recode(factor, !!!setNames(clean_factor$clean, clean_factor$orig)),
    level = recode(level, !!!setNames(clean_level$clean, clean_level$orig)),
    fmt_level = recode(fmt_level, !!!setNames(clean_fmt_level$clean, clean_fmt_level$orig))
  )

# Boxplot of the per-run AMEs (type == 'perm', non-baseline levels) for one
# model, factors in rows and one column per group. coef = 0 and
# outlier.colour = NA draw only the box (quartiles and median). The two models'
# plots were previously copy-pasted; both now share this builder.
plot_repeat_ame_boxplot <- function(save_AME) {
  ggplot(
    save_AME %>% filter(type == 'perm') %>% filter(!baseline),
    aes(x = fmt_level, y = marginal_effect)
  ) +
    geom_boxplot(outlier.colour = NA, coef = 0) +
    coord_flip(ylim = c(-0.4, 0.4)) +
    geom_hline(aes(yintercept = 0)) +
    facet_grid(factor ~ paste0('Group ', group),
               scales = 'free_y', space = 'free_y', switch = 'y',
               labeller = label_wrap_gen()) +
    theme_bw(base_size = 8) +
    theme(strip.text.y.left = element_text(angle = 0),
          panel.spacing = unit(0.1, 'lines'), strip.placement = 'outside') +
    xlab('Factor') + ylab('Effect')
}

g_boxplot_2 <- plot_repeat_ame_boxplot(repeat_final_output[[2]]$save_AME)
g_boxplot_3 <- plot_repeat_ame_boxplot(repeat_final_output[[3]]$save_AME)

g_both <- ggarrange(g_boxplot_2, g_boxplot_3)
ggsave(g_both, filename = 'figures/app_repeat_hh_AME.pdf', width = 8.5, height = 6)


################ FIGURE A15 ##################
# Estimate the average absolute effect for the moderators
absolute_AME_3 <- margeff_moderators(est_mbo3, abs_diff = TRUE)
absolute_AME_2 <- margeff_moderators(est_mbo2, abs_diff = TRUE)

# Stack both models, tagging each row with its number of groups
mfx_moderator <- rbind(
  mutate(absolute_AME_3$data, num_group = '3 groups'),
  mutate(absolute_AME_2$data, num_group = '2 groups')
)

# Mean with interval [ll, ul] per covariate; red stars (pch 8) mark
# abs(point_estimate_average).
g_mfx_abs <- ggplot(mfx_moderator) +
  geom_point(aes(x = fmt_name, y = mean)) +
  scale_x_discrete(breaks = clean_mfx_plot$orig,
                   labels = clean_mfx_plot$clean) +
  facet_wrap(~num_group + paste0('Group ', K), nrow = 1) +
  theme_bw() +
  coord_flip() +
  geom_errorbar(aes(x = fmt_name, ymin = ll, ymax = ul)) +
  ylab('Average Absolute Change in Posterior Predictive Probability of Group Membership') +
  xlab('Covariate') +
  geom_hline(aes(yintercept = 0), linetype = 'dashed') +
  geom_point(aes(x = fmt_name, y = abs(point_estimate_average)),
             col = 'red', pch = 8)

ggsave(plot = g_mfx_abs,
       filename = 'figures/app_abs_mod.pdf',
       width = 8, height = 6)

########## TABLE A.3 ##############

# Get table of AIC and BIC for K=1:4. BIC values come from the models in
# out_BIC.RDS (loaded above), AIC values from those in out_AIC.RDS.
# vapply() guarantees one number per model (sapply() would silently return
# a matrix or list if BIC()/AIC() ever returned something else).
ic_main <- data.frame(
  BIC = vapply(estimated_models[1:4], BIC, numeric(1)),
  AIC = vapply(readRDS('main_models/out_AIC.RDS')[1:4], AIC, numeric(1))
)

# Body rows of a LaTeX tabular; the surrounding table environment is not
# written here.
ic_out <- c(
  '$K=1$ & $K=2$ & $K=3$ & $K=4$ \\\\',
  paste0(paste(round(ic_main$BIC), collapse = ' & '), '\\\\'),
  '\\hline',
  '\\multicolumn{4}{c}{Optimizing AIC over $\\lambda$} \\\\',
  '$K=1$ & $K=2$ & $K=3$ & $K=4$ \\\\',
  paste0(paste(round(ic_main$AIC), collapse = ' & '), '\\\\ \\hline\\hline')
)
writeLines(text = ic_out, con = 'figures/compare_ic_hh.tex')
